﻿#region Copyright and License
// Copyright 2010..2016 Alexander Reinert
// 
// This file is part of the ARSoft.Tools.Net - C# DNS client/server and SPF Library (http://arsofttoolsnet.codeplex.com/)
// 
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// 
//   http://www.apache.org/licenses/LICENSE-2.0
// 
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#endregion

using System;
using System.Collections;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using System.Text.RegularExpressions;
using Org.BouncyCastle.Security;

namespace ARSoft.Tools.Net.Dns
{
    /// <summary>
    ///   Class representing a DNS zone
    /// </summary>
    public class Zone : ICollection<DnsRecordBase>
    {
        // Matches a whole master file line; group "data" receives the part before an optional trailing ';' comment.
        // Double-quoted strings (with backslash-escaped quotes) are consumed as a unit, so a ';' inside quotes
        // does not start a comment.
        // NOTE(review): a line containing an unterminated quote before any ';' does not match at all, so callers
        // reading Groups["data"].Value get an empty string for it - confirm this is intended.
        private static readonly Regex _commentRemoverRegex = new Regex(@"^(?<data>(\\\""|[^\""]|(?<!\\)\"".*?(?<!\\)\"")*?)(;.*)?$", RegexOptions.Compiled | RegexOptions.ExplicitCapture);
        // Tokenizes a comment-free line: group 1 captures an unquoted token (no whitespace, no quote), group 2
        // captures the content of a double-quoted string without the surrounding quotes.
        private static readonly Regex _lineSplitterRegex = new Regex("([^\\s\"]+)|\"(.*?(?<!\\\\))\"", RegexOptions.Compiled);

        // The records of this zone, in insertion order
        private readonly List<DnsRecordBase> _records;

        /// <summary>
        ///   Gets the name of the Zone
        /// </summary>
        public DomainName Name { get; }

        /// <summary>
        ///   Creates a new, empty zone
        /// </summary>
        /// <param name="name">The name of the zone</param>
        public Zone(DomainName name)
            : this(name, 0)
        {
        }

        /// <summary>
        ///   Creates a new zone holding a copy of the given records
        /// </summary>
        /// <param name="name">The name of the zone</param>
        /// <param name="collection">Records to copy into the new zone; the zone does not keep a reference to this collection</param>
        public Zone(DomainName name, IEnumerable<DnsRecordBase> collection)
        {
            _records = new List<DnsRecordBase>(collection);
            Name = name;
        }

        /// <summary>
        ///   Creates a new, empty zone whose record storage is pre-sized
        /// </summary>
        /// <param name="name">The name of the zone</param>
        /// <param name="capacity">Number of records the zone can hold before its storage has to grow</param>
        public Zone(DomainName name, int capacity)
        {
            _records = new List<DnsRecordBase>(capacity);
            Name = name;
        }

        /// <summary>
        ///   Loads a Zone from a master file
        /// </summary>
        /// <param name="name">The name of the zone</param>
        /// <param name="zoneFile">Path to the Zone master file; $INCLUDE paths inside it are resolved relative to its directory</param>
        /// <returns>A new instance of the Zone class</returns>
        public static Zone ParseMasterFile(DomainName name, string zoneFile)
        {
            // Open read-only and allow concurrent readers. File.Open(path, FileMode) alone requests read/write
            // access with FileShare.None, which fails for write-protected zone files and locks the file while parsing.
            // The reader has to wrap a FileStream so that $INCLUDE entries can resolve relative paths.
            using (var fs = File.Open(zoneFile, FileMode.Open, FileAccess.Read, FileShare.Read))
            {
                using (StreamReader reader = new StreamReader(fs))
                {
                    return ParseMasterFile(name, reader);
                }
            }
        }

        /// <summary>
        ///   Loads a Zone from a stream of master file data
        /// </summary>
        /// <remarks>
        ///   The stream is closed once parsing has finished. $INCLUDE entries can only be resolved when the stream is a FileStream.
        /// </remarks>
        /// <param name="name">The name of the zone</param>
        /// <param name="zoneFile">Stream containing the zone master data</param>
        /// <returns>A new instance of the Zone class</returns>
        public static Zone ParseMasterFile(DomainName name, Stream zoneFile)
        {
            var masterDataReader = new StreamReader(zoneFile);
            try
            {
                return ParseMasterFile(name, masterDataReader);
            }
            finally
            {
                masterDataReader.Dispose();
            }
        }

        /// <summary>
        ///   Parses master file data from a reader and builds the resulting zone
        /// </summary>
        /// <remarks>
        ///   When the data contains a SOA record, every parsed record whose TTL is 0 afterwards receives the
        ///   negative caching TTL of that SOA record. More than one SOA record results in an InvalidOperationException.
        /// </remarks>
        /// <param name="name">The name of the zone; also the initial origin for relative names</param>
        /// <param name="reader">Reader providing the master file data</param>
        /// <returns>A new zone containing the parsed records</returns>
        private static Zone ParseMasterFile(DomainName name, StreamReader reader)
        {
            // placeholder "previous record" supplying owner name and class to a first record that omits them
            var initialRecord = new UnknownRecord(name, RecordType.Invalid, RecordClass.INet, 0, new byte[] { });
            var parsedRecords = ParseRecords(reader, name, 0, initialRecord);

            var soa = (SoaRecord) parsedRecords.SingleOrDefault(r => r.RecordType == RecordType.Soa);
            if (soa != null)
            {
                foreach (var record in parsedRecords.Where(r => r.TimeToLive == 0))
                {
                    record.TimeToLive = soa.NegativeCachingTTL;
                }
            }

            return new Zone(name, parsedRecords);
        }

        /// <summary>
        ///   Parses the resource records and control entries contained in master file data
        /// </summary>
        /// <remarks>
        ///   The control entries $ORIGIN, $TTL and $INCLUDE are handled exclusively, i.e. a control entry is never
        ///   additionally interpreted as a resource record. Lines without any token (blank or comment-only lines,
        ///   even when indented) are skipped.
        /// </remarks>
        /// <param name="reader">Reader for the master file data</param>
        /// <param name="origin">Origin appended to relative domain names; replaced by $ORIGIN entries</param>
        /// <param name="ttl">TTL used for records without explicit TTL; replaced by $TTL entries and by explicit record TTLs</param>
        /// <param name="lastRecord">Previously parsed record; supplies owner name, class and type when a record omits them</param>
        /// <returns>The parsed records, including the records of included files</returns>
        /// <exception cref="NotSupportedException">An $INCLUDE entry is found while the data does not come from a file</exception>
        private static List<DnsRecordBase> ParseRecords(StreamReader reader, DomainName origin, int ttl, DnsRecordBase lastRecord)
        {
            List<DnsRecordBase> records = new List<DnsRecordBase>();

            while (!reader.EndOfStream)
            {
                string line = ReadRecordLine(reader);

                if (String.IsNullOrEmpty(line))
                    continue;

                string[] parts = _lineSplitterRegex.Matches(line).Cast<Match>().Select(x => x.Groups.Cast<Group>().Last(g => g.Success).Value.FromMasterfileLabelRepresentation()).ToArray();

                // whitespace-only lines (e.g. indented comment lines) contain no tokens at all
                if (parts.Length == 0)
                    continue;

                if (parts[0].Equals("$origin", StringComparison.OrdinalIgnoreCase))
                {
                    origin = DomainName.ParseFromMasterfile(parts[1]);
                }
                else if (parts[0].Equals("$ttl", StringComparison.OrdinalIgnoreCase))
                {
                    ttl = Int32.Parse(parts[1]);
                }
                else if (parts[0].Equals("$include", StringComparison.OrdinalIgnoreCase))
                {
                    FileStream fileStream = reader.BaseStream as FileStream;

                    if (fileStream == null)
                        throw new NotSupportedException("Includes only supported when loading files");

                    // relative include paths are resolved against the directory of the including file
                    // ReSharper disable once AssignNullToNotNullAttribute
                    string path = Path.Combine(new FileInfo(fileStream.Name).DirectoryName, parts[1]);

                    DomainName includeOrigin = (parts.Length > 2) ? DomainName.ParseFromMasterfile(parts[2]) : origin;

                    // read-only access, so write-protected include files can be loaded and are not locked exclusively
                    using (var fs = File.Open(path, FileMode.Open, FileAccess.Read, FileShare.Read))
                    {
                        using (StreamReader includeReader = new StreamReader(fs))
                        {
                            records.AddRange(ParseRecords(includeReader, includeOrigin, ttl, lastRecord));
                        }
                    }
                }
                else
                {
                    string domainString;
                    RecordType recordType;
                    RecordClass recordClass;
                    int recordTtl;
                    string[] rrData;

                    if (Int32.TryParse(parts[0], out recordTtl))
                    {
                        // no domain, starts with ttl
                        if (RecordClassHelper.TryParseShortString(parts[1], out recordClass, false))
                        {
                            // second is record class
                            domainString = null;
                            recordType = RecordTypeHelper.ParseShortString(parts[2]);
                            rrData = parts.Skip(3).ToArray();
                        }
                        else
                        {
                            // no record class
                            domainString = null;
                            recordClass = RecordClass.Invalid;
                            recordType = RecordTypeHelper.ParseShortString(parts[1]);
                            rrData = parts.Skip(2).ToArray();
                        }
                    }
                    else if (RecordClassHelper.TryParseShortString(parts[0], out recordClass, false))
                    {
                        // no domain, starts with record class
                        if (Int32.TryParse(parts[1], out recordTtl))
                        {
                            // second is ttl
                            domainString = null;
                            recordType = RecordTypeHelper.ParseShortString(parts[2]);
                            rrData = parts.Skip(3).ToArray();
                        }
                        else
                        {
                            // no ttl
                            recordTtl = 0;
                            domainString = null;
                            recordType = RecordTypeHelper.ParseShortString(parts[1]);
                            rrData = parts.Skip(2).ToArray();
                        }
                    }
                    else if (RecordTypeHelper.TryParseShortString(parts[0], out recordType))
                    {
                        // no domain, starts with record type: the record data follows directly after the type
                        recordTtl = 0;
                        recordClass = RecordClass.Invalid;
                        domainString = null;
                        rrData = parts.Skip(1).ToArray();
                    }
                    else
                    {
                        domainString = parts[0];

                        if (Int32.TryParse(parts[1], out recordTtl))
                        {
                            // domain, second is ttl
                            if (RecordClassHelper.TryParseShortString(parts[2], out recordClass, false))
                            {
                                // third is record class
                                recordType = RecordTypeHelper.ParseShortString(parts[3]);
                                rrData = parts.Skip(4).ToArray();
                            }
                            else
                            {
                                // no record class
                                recordClass = RecordClass.Invalid;
                                recordType = RecordTypeHelper.ParseShortString(parts[2]);
                                rrData = parts.Skip(3).ToArray();
                            }
                        }
                        else if (RecordClassHelper.TryParseShortString(parts[1], out recordClass, false))
                        {
                            // domain, second is record class
                            if (Int32.TryParse(parts[2], out recordTtl))
                            {
                                // third is ttl
                                recordType = RecordTypeHelper.ParseShortString(parts[3]);
                                rrData = parts.Skip(4).ToArray();
                            }
                            else
                            {
                                // no ttl
                                recordTtl = 0;
                                recordType = RecordTypeHelper.ParseShortString(parts[2]);
                                rrData = parts.Skip(3).ToArray();
                            }
                        }
                        else
                        {
                            // domain with record type
                            recordType = RecordTypeHelper.ParseShortString(parts[1]);
                            recordTtl = 0;
                            recordClass = RecordClass.Invalid;
                            rrData = parts.Skip(2).ToArray();
                        }
                    }

                    // omitted owner name: repeat the previous owner; "@": the origin; otherwise absolute or relative to origin
                    DomainName domain;
                    if (String.IsNullOrEmpty(domainString))
                    {
                        domain = lastRecord.Name;
                    }
                    else if (domainString == "@")
                    {
                        domain = origin;
                    }
                    else if (domainString.EndsWith("."))
                    {
                        domain = DomainName.ParseFromMasterfile(domainString);
                    }
                    else
                    {
                        domain = DomainName.ParseFromMasterfile(domainString) + origin;
                    }

                    if (recordClass == RecordClass.Invalid)
                    {
                        recordClass = lastRecord.RecordClass;
                    }

                    if (recordType == RecordType.Invalid)
                    {
                        recordType = lastRecord.RecordType;
                    }

                    // a TTL of 0 means "not specified"; an explicit TTL becomes the default for following records
                    if (recordTtl == 0)
                    {
                        recordTtl = ttl;
                    }
                    else
                    {
                        ttl = recordTtl;
                    }

                    lastRecord = DnsRecordBase.Create(recordType);
                    lastRecord.RecordType = recordType;
                    lastRecord.Name = domain;
                    lastRecord.RecordClass = recordClass;
                    lastRecord.TimeToLive = recordTtl;

                    // "\#" introduces record data in the generic (unknown record type) format
                    if ((rrData.Length > 0) && (rrData[0] == @"\#"))
                    {
                        lastRecord.ParseUnknownRecordData(rrData);
                    }
                    else
                    {
                        lastRecord.ParseRecordData(origin, rrData);
                    }

                    records.Add(lastRecord);
                }
            }

            return records;
        }

        /// <summary>
        ///   Reads one logical record line, joining physical lines that are grouped by parentheses
        /// </summary>
        /// <remarks>
        ///   Comments are removed from every physical line and the grouping parentheses are replaced by spaces.
        ///   A group may open and close on the same physical line.
        /// </remarks>
        /// <param name="reader">Reader positioned at the start of a physical line; callers check EndOfStream first</param>
        /// <returns>The logical line without comments and grouping parentheses</returns>
        /// <exception cref="FormatException">The data ends before an opened parenthesis is closed</exception>
        private static string ReadRecordLine(StreamReader reader)
        {
            string line = ReadLineWithoutComment(reader);

            int openPos = line.IndexOf('(');
            if (openPos == -1)
                return line;

            StringBuilder sb = new StringBuilder();
            sb.Append(line.Substring(0, openPos));
            sb.Append(" ");

            // The closing parenthesis is searched for in the rest of the first line too; otherwise a group
            // opened and closed on one line would swallow the following lines.
            string remainder = line.Substring(openPos + 1);

            while (true)
            {
                int closePos = remainder.IndexOf(')');
                if (closePos != -1)
                {
                    sb.Append(remainder.Substring(0, closePos));
                    sb.Append(" ");
                    sb.Append(remainder.Substring(closePos + 1));
                    return sb.ToString();
                }

                sb.Append(remainder);
                sb.Append(" ");

                if (reader.EndOfStream)
                    throw new FormatException("Unexpected end of master file data: missing closing parenthesis");

                remainder = ReadLineWithoutComment(reader);
            }
        }

        /// <summary>
        ///   Reads the next physical line and strips a trailing ';' comment from it
        /// </summary>
        /// <param name="reader">Reader to take the line from</param>
        /// <returns>The data part of the line, as captured by the "data" group of the comment remover expression</returns>
        private static string ReadLineWithoutComment(StreamReader reader)
        {
            // ReSharper disable once AssignNullToNotNullAttribute
            Match lineMatch = _commentRemoverRegex.Match(reader.ReadLine());
            Group dataGroup = lineMatch.Groups["data"];
            return dataGroup.Value;
        }

        /// <summary>
        ///   Signs a zone
        /// </summary>
        /// <param name="keys">
        ///   A list of keys to sign the zone; keys with the Secure Entry Point flag are used as key signing keys,
        ///   all other keys as zone signing keys
        /// </param>
        /// <param name="inception">The inception date of the signatures</param>
        /// <param name="expiration">The expiration date of the signatures</param>
        /// <param name="nsec3Algorithm">The NSEC3 algorithm (or 0 when NSEC should be used)</param>
        /// <param name="nsec3Iterations">The number of iterations when NSEC3 is used (0 to 65535)</param>
        /// <param name="nsec3Salt">The salt when NSEC3 is used (at most 255 bytes); a random 8 byte salt is generated when null</param>
        /// <param name="nsec3OptOut">true, if NSEC3 OptOut should be used for delegations without DS record</param>
        /// <returns>A signed zone</returns>
        /// <exception cref="ArgumentNullException">keys is null</exception>
        /// <exception cref="ArgumentException">No keys, or at least one unusable key, or an overlong NSEC3 salt was provided</exception>
        /// <exception cref="ArgumentOutOfRangeException">nsec3Iterations is outside of 0..65535 while NSEC3 is used</exception>
        /// <exception cref="InvalidOperationException">The zone does not contain a SOA record</exception>
        public Zone Sign(List<DnsKeyRecord> keys, DateTime inception, DateTime expiration, NSec3HashAlgorithm nsec3Algorithm = 0, int nsec3Iterations = 10, byte[] nsec3Salt = null, bool nsec3OptOut = false)
        {
            // All exceptions thrown here derive from Exception, so callers catching Exception keep working.
            if (keys == null)
                throw new ArgumentNullException(nameof(keys), "No DNS Keys were provided");

            if (keys.Count == 0)
                throw new ArgumentException("No DNS Keys were provided", nameof(keys));

            if (!keys.All(x => x.IsZoneKey))
                throw new ArgumentException("At least one DNS key without Zone Key Flag was provided", nameof(keys));

            if (keys.Any(x => (x.PrivateKey == null) || (x.PrivateKey.Length == 0)))
                throw new ArgumentException("For at least one DNS key no Private Key was provided", nameof(keys));

            if (keys.Any(x => (x.Protocol != 3) || ((nsec3Algorithm != 0) ? !x.Algorithm.IsCompatibleWithNSec3() : !x.Algorithm.IsCompatibleWithNSec())))
                throw new ArgumentException("At least one invalid DNS key was provided", nameof(keys));

            if (nsec3Algorithm != 0)
            {
                // The iterations are cast to ushort when the NSEC3/NSEC3PARAM records are created, so larger or
                // negative values would silently wrap around; the salt length field is a single octet (RFC 5155).
                if ((nsec3Iterations < 0) || (nsec3Iterations > UInt16.MaxValue))
                    throw new ArgumentOutOfRangeException(nameof(nsec3Iterations), "The number of NSEC3 iterations must be between 0 and 65535");

                if ((nsec3Salt != null) && (nsec3Salt.Length > Byte.MaxValue))
                    throw new ArgumentException("The NSEC3 salt must not be longer than 255 bytes", nameof(nsec3Salt));
            }

            // the signing helpers need the SOA record for class and negative caching TTL of NSEC/NSEC3 records
            if (!_records.OfType<SoaRecord>().Any())
                throw new InvalidOperationException("The zone does not contain a SOA record");

            List<DnsKeyRecord> keySigningKeys = keys.Where(x => x.IsSecureEntryPoint).ToList();
            List<DnsKeyRecord> zoneSigningKeys = keys.Where(x => !x.IsSecureEntryPoint).ToList();

            if (nsec3Algorithm == 0)
            {
                return SignWithNSec(inception, expiration, zoneSigningKeys, keySigningKeys);
            }
            else
            {
                return SignWithNSec3(inception, expiration, zoneSigningKeys, keySigningKeys, nsec3Algorithm, nsec3Iterations, nsec3Salt, nsec3OptOut);
            }
        }

        /// <summary>
        ///   Signs the zone using NSEC records for authenticated denial of existence
        /// </summary>
        /// <param name="inception">The inception date of the signatures</param>
        /// <param name="expiration">The expiration date of the signatures</param>
        /// <param name="zoneSigningKeys">Keys signing every authoritative RRset and the NSEC records</param>
        /// <param name="keySigningKeys">Keys additionally signing the DNSKEY RRset</param>
        /// <returns>A new zone containing the records, their signatures, the NSEC chain and the unsigned glue records</returns>
        private Zone SignWithNSec(DateTime inception, DateTime expiration, List<DnsKeyRecord> zoneSigningKeys, List<DnsKeyRecord> keySigningKeys)
        {
            var soaRecord = _records.OfType<SoaRecord>().First();
            // owner names of NS RRsets other than the apex, i.e. delegations to sub zones
            var subZones = _records.Where(x => (x.RecordType == RecordType.Ns) && (x.Name != Name)).Select(x => x.Name).Distinct().ToList();
            // records below a delegation point (glue) are neither signed nor part of the NSEC chain
            var glueRecords = _records.Where(x => subZones.Any(y => x.Name.IsSubDomainOf(y))).ToList();
            // authoritative records plus the keys, grouped by owner name; within a name SOA comes first, then by type value
            var recordsByName = _records.Except(glueRecords).Union(zoneSigningKeys).Union(keySigningKeys).GroupBy(x => x.Name).Select(x => new Tuple<DomainName, List<DnsRecordBase>>(x.Key, x.OrderBy(y => y.RecordType == RecordType.Soa ? -1 : (int)y.RecordType).ToList())).OrderBy(x => x.Item1).ToList();

            Zone res = new Zone(Name, Count * 3);

            for (int i = 0; i < recordsByName.Count; i++)
            {
                List<RecordType> recordTypes = new List<RecordType>();

                DomainName currentName = recordsByName[i].Item1;

                foreach (var recordsByType in recordsByName[i].Item2.GroupBy(x => x.RecordType))
                {
                    List<DnsRecordBase> records = recordsByType.ToList();

                    recordTypes.Add(recordsByType.Key);
                    res.AddRange(records);

                    // do not sign nameserver delegations for sub zones
                    if ((records[0].RecordType == RecordType.Ns) && (currentName != Name))
                        continue;

                    foreach (var key in zoneSigningKeys)
                    {
                        res.Add(new RrSigRecord(records, key, inception, expiration));
                    }
                    if (records[0].RecordType == RecordType.DnsKey)
                    {
                        foreach (var key in keySigningKeys)
                        {
                            res.Add(new RrSigRecord(records, key, inception, expiration));
                        }
                    }
                }

                // RFC 4035, section 2.3: the bitmap of every NSEC record has to contain NSEC and RRSIG, because the
                // NSEC record itself is signed. This also applies to delegation points without DS record, where no
                // other RRset is signed. Each type is listed only once.
                if (!recordTypes.Contains(RecordType.RrSig))
                    recordTypes.Add(RecordType.RrSig);
                if (!recordTypes.Contains(RecordType.NSec))
                    recordTypes.Add(RecordType.NSec);

                // the last NSEC record points back to the first name, closing the chain
                NSecRecord nsecRecord = new NSecRecord(currentName, soaRecord.RecordClass, soaRecord.NegativeCachingTTL, recordsByName[(i + 1) % recordsByName.Count].Item1, recordTypes);
                res.Add(nsecRecord);

                foreach (var key in zoneSigningKeys)
                {
                    res.Add(new RrSigRecord(new List<DnsRecordBase>() { nsecRecord }, key, inception, expiration));
                }
            }

            res.AddRange(glueRecords);

            return res;
        }

        /// <summary>
        ///   Signs the zone using NSEC3 records for authenticated denial of existence
        /// </summary>
        /// <param name="inception">The inception date of the signatures</param>
        /// <param name="expiration">The expiration date of the signatures</param>
        /// <param name="zoneSigningKeys">Keys signing every authoritative RRset and the NSEC3 records</param>
        /// <param name="keySigningKeys">Keys additionally signing the DNSKEY RRset</param>
        /// <param name="nsec3Algorithm">The hash algorithm of the NSEC3 records</param>
        /// <param name="nsec3Iterations">The number of additional hash iterations</param>
        /// <param name="nsec3Salt">The salt of the hash; a random 8 byte salt is generated when null</param>
        /// <param name="nsec3OptOut">true, if delegations without DS record should be left out of the NSEC3 chain</param>
        /// <returns>A new zone containing the records, their signatures, the NSEC3 chain and the unsigned records</returns>
        private Zone SignWithNSec3(DateTime inception, DateTime expiration, List<DnsKeyRecord> zoneSigningKeys, List<DnsKeyRecord> keySigningKeys, NSec3HashAlgorithm nsec3Algorithm, int nsec3Iterations, byte[] nsec3Salt, bool nsec3OptOut)
        {
            var soaRecord = _records.OfType<SoaRecord>().First();
            var subZoneNameserver = _records.Where(x => (x.RecordType == RecordType.Ns) && (x.Name != Name)).ToList();
            var subZones = subZoneNameserver.Select(x => x.Name).Distinct().ToList();
            var unsignedRecords = _records.Where(x => subZones.Any(y => x.Name.IsSubDomainOf(y))).ToList(); // glue records
            if (nsec3OptOut)
                unsignedRecords = unsignedRecords.Union(subZoneNameserver.Where(x => !_records.Any(y => (y.RecordType == RecordType.Ds) && (y.Name == x.Name)))).ToList(); // delegations without DS record
            // signed records plus the keys, grouped by owner name; within a name SOA comes first, then by type value
            var recordsByName = _records.Except(unsignedRecords).Union(zoneSigningKeys).Union(keySigningKeys).GroupBy(x => x.Name).Select(x => new Tuple<DomainName, List<DnsRecordBase>>(x.Key, x.OrderBy(y => y.RecordType == RecordType.Soa ? -1 : (int)y.RecordType).ToList())).OrderBy(x => x.Item1).ToList();

            byte nsec3RecordFlags = (byte)(nsec3OptOut ? 1 : 0);

            Zone res = new Zone(Name, Count * 3);
            List<NSec3Record> nSec3Records = new List<NSec3Record>(Count);

            if (nsec3Salt == null)
                nsec3Salt = SecureRandom.GetSeed(8);

            // NOTE(review): assumes the first name in sort order is the zone apex - confirm against DomainName's ordering
            recordsByName[0].Item2.Add(new NSec3ParamRecord(soaRecord.Name, soaRecord.RecordClass, 0, nsec3Algorithm, 0, (ushort)nsec3Iterations, nsec3Salt));

            HashSet<DomainName> allNames = new HashSet<DomainName>();

            for (int i = 0; i < recordsByName.Count; i++)
            {
                List<RecordType> recordTypes = new List<RecordType>();

                DomainName currentName = recordsByName[i].Item1;

                foreach (var recordsByType in recordsByName[i].Item2.GroupBy(x => x.RecordType))
                {
                    List<DnsRecordBase> records = recordsByType.ToList();

                    recordTypes.Add(recordsByType.Key);
                    res.AddRange(records);

                    // do not sign nameserver delegations for sub zones
                    if ((records[0].RecordType == RecordType.Ns) && (currentName != Name))
                        continue;

                    // RRSIG belongs to the type bitmap once, however many RRsets at this name are signed
                    if (!recordTypes.Contains(RecordType.RrSig))
                        recordTypes.Add(RecordType.RrSig);

                    foreach (var key in zoneSigningKeys)
                    {
                        res.Add(new RrSigRecord(records, key, inception, expiration));
                    }
                    if (records[0].RecordType == RecordType.DnsKey)
                    {
                        foreach (var key in keySigningKeys)
                        {
                            res.Add(new RrSigRecord(records, key, inception, expiration));
                        }
                    }
                }

                // the own hash is stored as "next hashed owner name" for now; the chain is linked after sorting
                byte[] hash = recordsByName[i].Item1.GetNSec3Hash(nsec3Algorithm, nsec3Iterations, nsec3Salt);
                nSec3Records.Add(new NSec3Record(DomainName.ParseFromMasterfile(hash.ToBase32HexString()) + Name, soaRecord.RecordClass, soaRecord.NegativeCachingTTL, nsec3Algorithm, nsec3RecordFlags, (ushort)nsec3Iterations, nsec3Salt, hash, recordTypes));

                // ancestors of the current name without records of their own (empty non-terminals) get an NSEC3
                // record with an empty type list
                // NOTE(review): presumably GetParentName(j) strips j labels, so j runs from the apex down to the direct parent - confirm
                allNames.Add(currentName);
                for (int j = currentName.LabelCount - Name.LabelCount; j > 0; j--)
                {
                    DomainName possibleNonTerminal = currentName.GetParentName(j);

                    if (!allNames.Contains(possibleNonTerminal))
                    {
                        hash = possibleNonTerminal.GetNSec3Hash(nsec3Algorithm, nsec3Iterations, nsec3Salt);
                        nSec3Records.Add(new NSec3Record(DomainName.ParseFromMasterfile(hash.ToBase32HexString()) + Name, soaRecord.RecordClass, soaRecord.NegativeCachingTTL, nsec3Algorithm, nsec3RecordFlags, (ushort)nsec3Iterations, nsec3Salt, hash, new List<RecordType>()));

                        allNames.Add(possibleNonTerminal);
                    }
                }
            }

            nSec3Records = nSec3Records.OrderBy(x => x.Name).ToList();

            // Link the chain: every record still holds its own hash, so record i-1 takes the (not yet modified)
            // hash of record i; the last record points to the first record's own hash saved beforehand.
            byte[] firstNextHashedOwnerName = nSec3Records[0].NextHashedOwnerName;

            for (int i = 1; i < nSec3Records.Count; i++)
            {
                nSec3Records[i - 1].NextHashedOwnerName = nSec3Records[i].NextHashedOwnerName;
            }

            nSec3Records[nSec3Records.Count - 1].NextHashedOwnerName = firstNextHashedOwnerName;

            foreach (var nSec3Record in nSec3Records)
            {
                res.Add(nSec3Record);

                foreach (var key in zoneSigningKeys)
                {
                    res.Add(new RrSigRecord(new List<DnsRecordBase>() { nSec3Record }, key, inception, expiration));
                }
            }

            res.AddRange(unsignedRecords);

            return res;
        }


        /// <summary>
        ///   Appends a single record to the Zone
        /// </summary>
        /// <param name="item">The record to append</param>
        public void Add(DnsRecordBase item) => _records.Add(item);

        /// <summary>
        ///   Appends several records to the Zone, keeping their order
        /// </summary>
        /// <param name="items">The records to append</param>
        public void AddRange(IEnumerable<DnsRecordBase> items) => _records.AddRange(items);

        /// <summary>
        ///   Deletes every record of the Zone
        /// </summary>
        public void Clear() => _records.Clear();

        /// <summary>
        ///   Checks whether the Zone holds the given record
        /// </summary>
        /// <param name="item">The record to look for</param>
        /// <returns>true, if the record is part of the Zone; otherwise, false</returns>
        public bool Contains(DnsRecordBase item) => _records.Contains(item);

        /// <summary>
        ///   Copies all records of the Zone into an array
        /// </summary>
        /// <param name="array">Destination array</param>
        /// <param name="arrayIndex">Position in the destination array at which copying starts</param>
        public void CopyTo(DnsRecordBase[] array, int arrayIndex) => _records.CopyTo(array, arrayIndex);

        /// <summary>
        ///   Gets how many records the Zone currently holds
        /// </summary>
        public int Count
        {
            get { return _records.Count; }
        }

        /// <summary>
        ///   Gets whether the Zone is read-only
        /// </summary>
        /// <value>Always false; records can be added and removed</value>
        bool ICollection<DnsRecordBase>.IsReadOnly
        {
            get { return false; }
        }

        /// <summary>
        ///   Deletes the first occurrence of a record from the Zone
        /// </summary>
        /// <param name="item">The record to delete</param>
        /// <returns>true, if the record was found and deleted; otherwise, false</returns>
        public bool Remove(DnsRecordBase item) => _records.Remove(item);

        /// <summary>
        ///   Gets an enumerator over the records of the Zone in their stored order
        /// </summary>
        /// <returns>An enumerator over the records of the Zone</returns>
        public IEnumerator<DnsRecordBase> GetEnumerator() => _records.GetEnumerator();

        /// <summary>
        ///   Gets a non-generic enumerator over the records of the Zone
        /// </summary>
        /// <returns>An enumerator over the records of the Zone</returns>
        IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
    }
}